#include "StdAfx.h"
#include "tallier.h"

#ifdef RSA_MOD_SIZE
#undef RSA_MOD_SIZE
#define RSA_MOD_SIZE 10
#endif

#include "../Lab5_EDS_RSA/RSA_System.cpp"
#include <algorithm>
#include <limits>
#include <utility>
using std::transform;
using std::for_each;
using std::cout;

// Builds a tallier for the given questionnaire.
// The RSA_System base is default-constructed and the questions are stored.
// _elections is sized to hold one counter per question; check_vote increments
// _elections[i] for each YES answer to question i.
// The by-value parameter is moved into the member rather than copied a second time.
tallier::tallier(questionnarie questions)
	:RSA_System(),
	_questions(std::move(questions))
{
	_elections.resize(_questions.size());
}


// Recovers the plain messages from sets of blinded messages.
//
// Set i of blinded_messages is paired with unblinded_keys[i]. Every message b
// in that set is mapped to
//     uint128::mult_m(uint128::pow_m(k, exp, n), b, n)
// i.e. k^exp * b mod n, assuming the usual semantics of pow_m/mult_m.
// Here exp = get_RSA_exp() and n = get_RSA_mod().
//
// The recovered messages are printed to stdout with SALT subtracted, and the
// whole batch is returned.
//
// Precondition (checked only by assert): both inputs have the same length.
const messages tallier::get_plain_messages(const messages& blinded_messages, const keys& unblinded_keys) const
{
	assert(blinded_messages.size()==unblinded_keys.size());

	// check_vote verifies signatures with get_RSA_exp(), so this is
	// (presumably) the public exponent, not the private `d`.
	uint128 pub_exp = get_RSA_exp();
	uint128 n = get_RSA_mod();

	messages plain_messages(blinded_messages.size());
	transform(blinded_messages.begin(), blinded_messages.end(), unblinded_keys.begin(), plain_messages.begin(), [&](const messages_set& ms_set, const key& k) -> messages_set
	{
		// k^exp mod n depends only on the key, so compute it once per set
		// instead of once per message.
		uint128 factor = uint128::pow_m(k, pub_exp, n);

		messages_set plain_set(ms_set.size());
		transform(ms_set.begin(), ms_set.end(), plain_set.begin(), [&](const message& b) -> message
		{
			return uint128::mult_m(factor, b, n);
		});
		return plain_set;
	});

	cout<<"Plain Messages:\n";
	for (const messages_set& ms : plain_messages)
	{
		for (uint128 m : ms)
		{
			cout<<"m = "<<m - SALT<<"\t";
		}
		cout<<"\n";
	}
	cout<<"\n\n";

	return plain_messages;
}

// Signs each blinded message with the tallier's private exponent.
//
// Entry j of the result is uint128::pow_m(blinded_messages_set[j], d, n),
// where n = get_RSA_mod(). The voter's blinding is never removed here.
// The signed values are printed to stdout before the set is returned.
const messages_set tallier::get_blind_signed_messages_set(const messages_set& blinded_messages_set) const
{
	uint128 modulus = get_RSA_mod();

	messages_set signed_set(blinded_messages_set.size());
	auto dst = signed_set.begin();
	for (auto src = blinded_messages_set.begin(); src != blinded_messages_set.end(); ++src, ++dst)
	{
		*dst = uint128::pow_m(*src, this->d, modulus);
	}

	cout<<"Signed blind messages Q:\n";
	for (uint128 q : signed_set)
	{
		cout<<"m = "<<q<<"\t";
	}
	cout<<"\n";

	return signed_set;
}

// Verifies a voter's signed ballot and, if the signature is valid, counts it.
//
// The signature is accepted when
//     uint128::sub_m(uint128::pow_m(signature, exp, n), message, n) == 0
// i.e. signature^exp == message (mod n), with exp = get_RSA_exp() and
// n = get_RSA_mod(). The verdict is printed to stdout.
//
// For an accepted ballot, (message - SALT) is read as a bit mask: if bit i is
// set, the answer to question i is YES and _elections[i] is incremented.
//
// Returns true if the signature verified and the vote was counted.
bool tallier::check_vote(signed_message msg)
{
	uint128 pub_exp = get_RSA_exp();
	uint128 n = get_RSA_mod();

	bool correct = uint128::sub_m(uint128::pow_m(msg.signature, pub_exp, n), msg.message, n) == 0;
	cout<<"The voter is "<< (correct ? "right! The vote is taken!": "wrong (. Sorry, you are malicious.")<<"\n\n";

	if(!correct)
	{
		return false;
	}

	// NOTE(review): if msg.message < SALT this subtraction presumably wraps
	// around. Trivially valid signatures such as 0^exp == 0 or 1^exp == 1
	// would then set many answer bits. Consider rejecting such messages.
	uint128 answers = msg.message - SALT;

	// The mask is built with `1UL << i`. Shifting by >= the bit width of
	// unsigned long (32 on MSVC) is undefined behaviour, so only that many
	// questions can be represented in the mask and counted.
	const size_t mask_bits = static_cast<size_t>(std::numeric_limits<unsigned long>::digits);
	const size_t total = _elections.size();
	const size_t counted = std::min(total, mask_bits);
	for(size_t i = 0; i < counted; ++i)
	{
		// answer to the i-th question is YES
		if( (answers & (1UL << i)) != 0)
		{
			++_elections[i];
		}
	}
	return true;
}